import torch

a = torch.zeros((2, 4))
b = torch.ones((2, 4))

out_1 = torch.cat((a, b), dim=0)
out_2 = torch.cat((a, b), dim=1)

print(out_1)
print(out_2)

# tensor([[0., 0., 0., 0.],
#        [0., 0., 0., 0.],
#        [1., 1., 1., 1.],
#        [1., 1., 1., 1.]])


# tensor([[0., 0., 0., 0., 1., 1., 1., 1.],
#        [0., 0., 0., 0., 1., 1., 1., 1.]])

# a = torch.linspace(1, 6, 6).view(2, 3)
# b = torch.linspace(7, 12, 6).view(2, 3)
# print(a)
# print(b)
# out = torch.stack((a, b), dim=0)
# print(out)
# print(out.shape)